import argparse
import pathlib
import json
import os
from tqdm import tqdm

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

def load_aokvqa(aokvqa_dir, split, version='v1p0'):
    """Load one split of the dataset from ``aokvqa_{version}_{split}.json``.

    Args:
        aokvqa_dir: Directory containing the annotation JSON files.
        split: One of ``'train'``, ``'val'``, ``'test'`` or ``'test_w_ans'``.
        version: Version tag embedded in the file name.

    Returns:
        The decoded JSON content of the annotation file.

    Raises:
        AssertionError: If ``split`` is not one of the accepted names.
        FileNotFoundError: If the annotation file does not exist.
    """
    # Kept as an assert so existing callers still observe AssertionError.
    assert split in ['train', 'val', 'test', 'test_w_ans'], f"unknown split: {split!r}"
    annotation_path = os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
    # The context manager closes the handle deterministically; JSON text is
    # UTF-8, so decode explicitly instead of relying on the locale default.
    with open(annotation_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    return dataset


def map_to_choices(dataset, predictions, device='cpu'):
    """Snap each predicted answer onto one of its question's choices.

    Predictions that already equal one of the question's choices are left
    alone. Any other prediction is replaced by the choice whose embedding
    has the highest cosine similarity to the prediction's embedding.

    Args:
        dataset: Either a list of question records (each with
            ``'question_id'`` and ``'choices'`` entries) or a dict keyed by
            question id holding such records.
        predictions: Dict mapping question id to the predicted answer.
            It is updated in place.
        device: Device the embedding model is moved to.

    Returns:
        The same ``predictions`` dict that was passed in.

    Raises:
        KeyError: If a predicted question id is absent from ``dataset``.
    """
    if isinstance(dataset, list):
        # Index the records by question id for direct lookup.
        dataset = {record['question_id']: record for record in dataset}

    # Fast path: skip loading the model when nothing needs remapping.
    already_valid = [
        answer in dataset[qid]['choices'] for qid, answer in predictions.items()
    ]
    if all(already_valid):
        return predictions

    model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d')
    model.to(device)
    for qid in tqdm(predictions.keys()):
        options = dataset[qid]['choices']
        answer = predictions[qid]
        if answer in options:
            continue
        # Row 0 is the prediction; the remaining rows are the choices.
        embeddings = model.encode([answer] + options, convert_to_tensor=True)
        similarities = cos_sim(embeddings[0], embeddings[1:])
        best_idx = similarities.argmax().item()
        predictions[qid] = options[best_idx]

    return predictions


if __name__ == '__main__':
    # Script entry point: load a dataset split and a predictions file, map
    # every prediction onto one of its question's choices, and write the
    # result as JSON. All defaults reproduce the previously hard-coded
    # values, so running with no arguments behaves as before.
    parser = argparse.ArgumentParser(
        description='Map free-form predictions onto multiple-choice options.'
    )
    parser.add_argument(
        '--aokvqa-dir', dest='aokvqa_dir', type=pathlib.Path,
        default=pathlib.Path('/home/test/yxl/MCoT/data/aokvqa'),
        help='Directory containing the aokvqa_<version>_<split>.json files.',
    )
    parser.add_argument(
        '--split', default='val',
        choices=['train', 'val', 'test', 'test_w_ans'],
        help='Dataset split whose choices are used for mapping.',
    )
    parser.add_argument(
        '--preds', type=pathlib.Path,
        default=pathlib.Path('/home/test/yxl/MCoT/aokvqa/results/qwen-test/SC.json'),
        help='Input JSON file mapping question id to predicted answer.',
    )
    parser.add_argument(
        '--out', type=pathlib.Path,
        default=pathlib.Path('/home/test/yxl/MCoT/aokvqa/results/qwen-test/SC_val.json'),
        help='Output JSON file for the mapped predictions.',
    )
    parser.add_argument(
        '--device', default='cpu',
        help='Device passed to map_to_choices for the embedding model.',
    )
    args = parser.parse_args()

    dataset = load_aokvqa(args.aokvqa_dir, args.split)
    with open(args.preds, 'r', encoding='utf-8') as f:
        predictions = json.load(f)
    predictions = map_to_choices(dataset, predictions, device=args.device)

    with open(args.out, 'w', encoding='utf-8') as f:
        json.dump(predictions, f)